using Base: method_argnames, kwarg_decl
using OrderedCollections: OrderedSet
using InteractiveUtils: @code_lowered

# Abstract supertype for the callable wrappers defined in this file. Concrete subtypes carry
# `name`, `inputs` and `outputs` fields (used by `Base.show` below) and are called with a
# dictionary of named inputs, returning a dictionary of named outputs.
abstract type AbstractExtendedFunction end

"""
    ExtendedFunction(f)

Wrap a plain function `f` together with its metadata: a `name` (leading underscore
stripped), the ordered set of input argument names and the ordered set of output names,
as computed by `metadata(f)`.

If `f` is already an `AbstractExtendedFunction`, its `f`, `name`, `inputs` and `outputs`
fields are copied instead of being re-derived.
"""
struct ExtendedFunction <: AbstractExtendedFunction
    f::Union{Function, AbstractExtendedFunction}
    name::String
    inputs::OrderedSet{String}
    outputs::OrderedSet{String}

    # Re-wrap an existing extended function by copying its fields.
    ExtendedFunction(wrapped::AbstractExtendedFunction) =
        new(wrapped.f, wrapped.name, wrapped.inputs, wrapped.outputs)

    # Wrap anything else by introspecting it (`metadata` requires a `Function`).
    ExtendedFunction(func) = new(func, metadata(func)...)
end

"""
    input_list(f::Function) -> OrderedSet{String}

Return the argument names of the last method of `f`: positional arguments (without the
implicit `#self#` slot) followed by keyword arguments, as strings, in declaration order.
"""
function input_list(f::Function)
    method = last(methods(f))
    positional = method_argnames(method)[2:end]  # first entry is `#self#`
    keywords = kwarg_decl(method)
    return OrderedSet(string.([positional; keywords]))
end

# Infer the output names of `f` by scraping its lowered code.
#
# Approach: take the last method of `f`, get its keyword-argument body function via
# `Base.bodyfunction`, lower it (`@code_lowered`) for dummy `Float64` arguments (one per
# declared positional/keyword argument, excluding `#self#`) followed by `f` itself, print
# the lowered code, and split the arguments of the LAST `Core.tuple(...)` call on commas.
#
# NOTE(review): because `f` is passed after all dummy arguments, this presumably only
# lines up for functions that take keyword arguments only (no positionals) — confirm.
# It also assumes the body ends by returning a tuple literal of named locals: with no
# `Core.tuple(` match, `collect(...)` is empty and `[end]` throws a BoundsError. Names are
# whatever the lowered-code printer emits (possibly `%n`-style SSA names) — verify.
function output_list(f::Function)
    # Dummy-argument count: all argnames (including `#self#`) plus keywords, minus `#self#`.
    num_params = length(method_argnames(last(methods(f))))+length(kwarg_decl(last(methods(f))))-1
    ast = string(@code_lowered Base.bodyfunction(last(methods(f)))(ones(num_params)..., f))
    # The last `Core.tuple(...)` on its own line is taken to be the returned tuple.
    out_tuple = collect(eachmatch(r"Core\.tuple\((.+)\)\n", ast))[end].captures[1]
    return OrderedSet(string.(strip.(split(out_tuple, ","))))
end

"""
    metadata(f::Function) -> (name, inputs, outputs)

Return the display name of `f` (its `nameof`, with one leading underscore removed if
present), the ordered input names from `input_list`, and the ordered output names from
`output_list`.
"""
function metadata(f::Function)
    rawname = String(nameof(f))
    displayname = rawname[1] == '_' ? rawname[2:end] : rawname
    return displayname, input_list(f), output_list(f)
end

# Compact one-line representation: `<Type(name): [in1, in2] -> [out1, out2]>`.
function Base.show(io::IO, f::AbstractExtendedFunction)
    typename = string(typeof(f))
    ins = join(f.inputs, ", ")
    outs = join(f.outputs, ", ")
    return print(io, "<$(typename)($(f.name)): [$(ins)] -> [$(outs)]>")
end

"""
    (f::AbstractExtendedFunction)(input_dict)

Call the wrapped function `f.f` with keyword arguments taken from `input_dict`. Only
entries whose key is in `f.inputs` are passed, and keys are converted to `Symbol`s.

The result is converted to a `Tuple` and zipped with `f.outputs`, giving a `Dict` from
output name to value.
"""
function (f::AbstractExtendedFunction)(input_dict)
    kwargs = Dict(Symbol(key) => val for (key, val) ∈ input_dict if key ∈ f.inputs)
    returned = f.f(; kwargs...)
    return Dict(zip(f.outputs, Tuple(returned)))
end

"""
    wrapped_call(f::AbstractExtendedFunction, input_dict; preprocess=nothing, postprocess=nothing)

Like calling `f(input_dict)`, but with optional value transforms.

- `preprocess`: applied to every input value that is passed on (only keys in `f.inputs`).
- `postprocess`: applied to every output value.

`nothing` disables the corresponding transform. Returns a `Dict` from output name to value.
"""
function wrapped_call(f::AbstractExtendedFunction, input_dict; preprocess=nothing, postprocess=nothing)
    # `identity` leaves values untouched when no preprocessing is requested.
    transform_in = isnothing(preprocess) ? identity : preprocess
    kwargs = Dict(Symbol(key) => transform_in(val) for (key, val) ∈ input_dict if key ∈ f.inputs)

    outputs = Dict(zip(f.outputs, Tuple(f.f(; kwargs...))))

    isnothing(postprocess) && return outputs
    return Dict(key => postprocess(val) for (key, val) ∈ outputs)
end

"""
    differentiable(f::AbstractExtendedFunction, input_dict; h=1e-4, twosided=false)

Build a `DifferentiableExtendedFunction` from `f`'s wrapped function and metadata. The
derivatives are evaluated around the base point `input_dict`, with default step `h` and
default finite-difference mode `twosided`.
"""
function differentiable(f::AbstractExtendedFunction, input_dict; h=1e-4, twosided=false)
    return DifferentiableExtendedFunction(
        f.f, f.name, f.inputs, f.outputs, input_dict;
        h=h, twosided=twosided,
    )
end

"""
    hide_zero_values(d; atol=1e-8, rtol=1e-5) -> Dict

Return a new `Dict` containing only the entries of `d` whose value is NOT approximately
zero. A value is dropped when every element `x` of it satisfies
`isapprox(0.0, x; atol=atol, rtol=rtol)`. Scalars iterate as a single element, and empty
collections count as all-zero.

The defaults reproduce the previously hard-coded tolerances.
"""
function hide_zero_values(d; atol=1e-8, rtol=1e-5)
    # Compare each element against zero directly instead of zipping with a freshly
    # allocated `zeros(length(v))` vector.
    isnegligible(v) = all(x -> isapprox(0.0, x; atol=atol, rtol=rtol), v)
    return Dict(k => v for (k, v) ∈ d if !isnegligible(v))
end

"""
    DifferentiableExtendedFunction(f, name, inputs, outputs, input_dict; h=1e-4, twosided=false)

Extended function that can be finite-differenced around the base point `input_dict`.

Fields:
- `f`, `name`, `inputs`, `outputs`: as for other extended functions.
- `input_dict`: base point.
- `output_dict`: cached output at the base point. It starts as `nothing` and is filled
  lazily by the one-sided `diff1`.
- `h`: default step size.
- `default_twosided`: default finite-difference mode used by `diff`.

The struct is mutable so that `output_dict` can be cached after construction.
"""
mutable struct DifferentiableExtendedFunction <: AbstractExtendedFunction
    f
    name
    inputs
    outputs
    input_dict
    output_dict
    h
    default_twosided

    function DifferentiableExtendedFunction(func, fname, ins, outs, base_point; h=1e-4, twosided=false)
        uncached_output = nothing
        return new(func, fname, ins, outs, base_point, uncached_output, h, twosided)
    end
end

"""
    diff(f::DifferentiableExtendedFunction, shock_dict; h=nothing, hide_zeros=false, twosided=nothing)

Finite-difference derivative of `f` in the direction `shock_dict`. Uses `diff2`
(two-sided) when `twosided` is true and `diff1` (one-sided) otherwise. A `twosided` of
`nothing` falls back to `f.default_twosided`.

NOTE(review): this defines a module-local `diff`; unless `Base.diff` is imported
explicitly, it presumably shadows `Base.diff` inside this module — verify.
"""
function diff(f::DifferentiableExtendedFunction, shock_dict; h=nothing, hide_zeros=false, twosided=nothing)
    use_central = isnothing(twosided) ? f.default_twosided : twosided
    if use_central
        return diff2(f, shock_dict; h=h, hide_zeros=hide_zeros)
    else
        return diff1(f, shock_dict; h=h, hide_zeros=hide_zeros)
    end
end

"""
    diff1(f::DifferentiableExtendedFunction, shock_dict; h=nothing, hide_zeros=false)

One-sided (forward) finite-difference derivative of `f` around `f.input_dict`.

Each input `k` present in both `shock_dict` and `f.input_dict` is moved to
`input_dict[k] .+ h .* shock_dict[k]`; shocks for unknown inputs are ignored. The output
at the base point is computed once and cached in `f.output_dict`.

Returns a `Dict` mapping each output name to `(shocked - base) / h`. `h` defaults to
`f.h`. With `hide_zeros=true`, approximately-zero entries are dropped via
`hide_zero_values`.
"""
function diff1(f::DifferentiableExtendedFunction, shock_dict; h=nothing, hide_zeros=false)
    step = isnothing(h) ? f.h : h

    # Lazily evaluate and cache the base-point output; later calls reuse it.
    if isnothing(f.output_dict)
        f.output_dict = f(f.input_dict)
    end

    base_inputs = f.input_dict
    bumped = Dict(k => base_inputs[k] .+ step .* shock
                  for (k, shock) ∈ shock_dict if k ∈ keys(base_inputs))
    shocked_outputs = f(Dict(base_inputs..., bumped...))

    base_outputs = f.output_dict
    derivatives = Dict(k => (shocked_outputs[k] - base_outputs[k]) / step
                       for k ∈ keys(base_outputs))

    return hide_zeros ? hide_zero_values(derivatives) : derivatives
end

"""
    diff2(f::DifferentiableExtendedFunction, shock_dict; h=nothing, hide_zeros=false)

Two-sided (central) finite-difference derivative of `f` around `f.input_dict`.

Each input `k` present in both `shock_dict` and `f.input_dict` is evaluated at
`input_dict[k] .± h .* shock_dict[k]`; shocks for unknown inputs are ignored. Returns a
`Dict` mapping each output name to `(up - down) / (2h)`.

`h` defaults to `f.h`. With `hide_zeros=true`, approximately-zero entries are dropped via
`hide_zero_values`.
"""
function diff2(f::DifferentiableExtendedFunction, shock_dict; h=nothing, hide_zeros=false)
    isnothing(h) && (h = f.h)

    base_inputs = f.input_dict
    # Filter against the function's own base point (the old code referenced an undefined
    # `input_dict`), and broadcast like `diff1` so array inputs accept scalar shocks.
    bumped_up = Dict(k => base_inputs[k] .+ h .* shock
                     for (k, shock) ∈ shock_dict if k ∈ keys(base_inputs))
    bumped_dn = Dict(k => base_inputs[k] .- h .* shock
                     for (k, shock) ∈ shock_dict if k ∈ keys(base_inputs))

    shocked_output_dict_up = f(Dict(base_inputs..., bumped_up...))
    shocked_output_dict_dn = f(Dict(base_inputs..., bumped_dn...))

    # Iterate over output NAMES; iterating the Dict itself would yield Pairs.
    derivative_dict = Dict(k => (shocked_output_dict_up[k] - shocked_output_dict_dn[k]) / (2 * h)
                           for k ∈ keys(shocked_output_dict_dn))

    hide_zeros && (derivative_dict = hide_zero_values(derivative_dict))

    return derivative_dict
end

# Abstract supertype for extended functions composed of several blocks. The concrete
# subtypes in this file keep the block graph in a `dag` field and use it to decide which
# blocks to visit.
abstract type AbstractCombinedExtendedFunction <: AbstractExtendedFunction end

"""
    CombinedExtendedFunction(fs; name=nothing)

Compose the functions `fs` into one extended function. Each element is wrapped in
`ExtendedFunction`, and the wrappers are arranged by `DAG`. `inputs` and `outputs` are
taken from the DAG, and `functions` maps each block name to its block, in the DAG's
block order.

If `name` is `nothing`, the name is derived from the block names: the single block's name,
or `"<first>_<last>"` for several blocks. Otherwise the given `name` is used.

Throws `ArgumentError` when no name is given and there are no blocks to derive one from.
"""
struct CombinedExtendedFunction <: AbstractCombinedExtendedFunction
    dag
    inputs
    outputs
    functions
    name

    function CombinedExtendedFunction(fs; name=nothing)
        _dag = DAG([ExtendedFunction(f) for f ∈ fs])
        _inputs = _dag.inputs
        _outputs = _dag.outputs
        _functions = OrderedDict(b.name => b for b ∈ _dag.blocks)

        if isnothing(name)
            names = collect(keys(_functions))
            isempty(names) && throw(ArgumentError(
                "cannot derive a name for a CombinedExtendedFunction with no blocks"))
            _name = length(names) == 1 ? names[1] : "$(names[1])_$(names[end])"
        else
            # Previously `_name` was left unassigned here, raising UndefVarError.
            _name = name
        end

        return new(_dag, _inputs, _outputs, _functions, _name)
    end
end

"""
    (f::AbstractCombinedExtendedFunction)(input_dict; outputs=nothing)

Evaluate the blocks of `f` in order, feeding each block the accumulated dictionary of all
inputs and previously computed outputs.

If `outputs` is given, only the blocks selected by `visit_from_outputs(f.dag, outputs)`
are run, and only the requested keys are returned. Otherwise the full accumulated
dictionary (inputs plus all outputs) is returned.

NOTE(review): this reads `f.functions`, which `DifferentiableCombinedExtendedFunction`
does not have (it stores `diff_functions`), so calling that subtype here raises — confirm
whether that is intended.
"""
function (f::AbstractCombinedExtendedFunction)(input_dict; outputs=nothing)
    blocks = collect(values(f.functions))
    if !isnothing(outputs)
        blocks = [blocks[i] for i ∈ visit_from_outputs(f.dag, outputs)]
    end

    # Each block sees everything computed so far; `merge` returns a fresh Dict each step.
    accumulated = foldl((acc, block) -> merge(acc, block(acc)), blocks; init=copy(input_dict))

    isnothing(outputs) && return accumulated
    return Dict(k => accumulated[k] for k ∈ outputs)
end

"""
    filter(f::AbstractCombinedExtendedFunction, function_list, inputs; outputs=nothing)

Select the elements of `function_list` whose indices are reachable from `inputs` in
`f.dag`. If `outputs` is given, the selection is further restricted to the indices that
also lead to `outputs`. Order follows the index collection returned by the DAG helpers.

NOTE(review): this defines a module-local `filter`; unless `Base.filter` is imported
explicitly, it presumably shadows `Base.filter` inside this module — verify.
"""
function filter(f::AbstractCombinedExtendedFunction, function_list, inputs; outputs=nothing)
    reachable = visit_from_inputs(f.dag, inputs)
    selected = isnothing(outputs) ? reachable :
        intersect(reachable, visit_from_outputs(f.dag, outputs))
    return [function_list[idx] for idx ∈ selected]
end

"""
    call_on_deviations(f::AbstractCombinedExtendedFunction, ss, dev_dict; outputs=nothing)

Re-evaluate only the blocks of `f` affected by the deviations in `dev_dict`, starting from
the state `ss` overridden by `dev_dict`.

Returns the outputs recomputed by those blocks, restricted to `outputs` when given.
"""
function call_on_deviations(f::AbstractCombinedExtendedFunction, ss, dev_dict; outputs=nothing)
    affected = filter(f, collect(values(f.functions)), dev_dict; outputs=outputs)

    recomputed = Dict()
    state = Dict(ss..., dev_dict...)  # deviations override the base state

    for block ∈ affected
        produced = block(state)
        merge!(recomputed, produced)
        merge!(state, produced)  # downstream blocks see the updated values
    end

    isnothing(outputs) && return recomputed
    return Dict(k => v for (k, v) ∈ recomputed if k ∈ outputs)
end

"""
    wrapped_call(f::AbstractCombinedExtendedFunction, input_dict; preprocess=nothing, postprocess=nothing)

Not supported for combined functions; always throws an `ErrorException`.
"""
function wrapped_call(f::AbstractCombinedExtendedFunction, input_dict; preprocess=nothing, postprocess=nothing)
    throw(ErrorException("Not Implemented"))
end

"""
    add(f::AbstractCombinedExtendedFunction, g)

Return a new `CombinedExtendedFunction` built from the blocks of `f` plus `g`. `g` may be a
single function or extended function, or any collection of them. The name is re-derived,
because no `name` keyword is passed.
"""
function add(f::AbstractCombinedExtendedFunction, g::Union{Function, AbstractExtendedFunction})
    current = collect(values(f.functions))
    return CombinedExtendedFunction(vcat(current, g))
end

function add(f::AbstractCombinedExtendedFunction, g)
    current = collect(values(f.functions))
    extra = collect(g)
    return CombinedExtendedFunction([current..., extra...])
end

"""
    remove(f::AbstractCombinedExtendedFunction, name)

Return a new `CombinedExtendedFunction` without the block(s) named `name`. `name` is
either a single string or a collection of block names.
"""
function remove(f::AbstractCombinedExtendedFunction, name::AbstractString)
    kept = [block for (blockname, block) ∈ f.functions if blockname != name]
    return CombinedExtendedFunction(kept)
end

function remove(f::AbstractCombinedExtendedFunction, name)
    kept = [block for (blockname, block) ∈ f.functions if blockname ∉ name]
    return CombinedExtendedFunction(kept)
end

# Return the `name => block` pairs of `f.functions` as an `OrderedSet`.
# NOTE(review): this yields Pairs, not just the blocks — confirm that is intended.
function children(f::AbstractCombinedExtendedFunction)
    blocks_by_name = f.functions
    return OrderedSet(blocks_by_name)
end

"""
    differentiable(f::AbstractCombinedExtendedFunction, input_dict; h=1e-5, twosided=false)

Build a `DifferentiableCombinedExtendedFunction` from the blocks, DAG and metadata of `f`.
Every block becomes differentiable around the base point `input_dict` with step `h`.
"""
function differentiable(f::AbstractCombinedExtendedFunction, input_dict; h=1e-5, twosided=false)
    return DifferentiableCombinedExtendedFunction(
        f.functions, f.dag, f.name, f.inputs, f.outputs, input_dict;
        h=h, twosided=twosided,
    )
end

"""
    DifferentiableCombinedExtendedFunction(functions, dag, name, inputs, outputs, input_dict; h=1e-5, twosided=false)

Combined function whose blocks have each been made differentiable around `input_dict`.
`diff_functions` maps block name to `differentiable(block, input_dict; h=h)`.

`twosided` is stored as `default_twosided` but is not forwarded to the blocks. The
combined `diff1`/`diff2` call the per-block `diff1`/`diff2` explicitly, so the
per-block default is not consulted on those paths.
"""
struct DifferentiableCombinedExtendedFunction <: AbstractCombinedExtendedFunction
    dag
    name
    inputs
    outputs
    diff_functions
    default_twosided

    function DifferentiableCombinedExtendedFunction(functions, dag, name, inputs, outputs, input_dict; h=1e-5, twosided=false)
        wrapped = OrderedDict()
        for (blockname, block) ∈ functions
            wrapped[blockname] = differentiable(block, input_dict; h=h)
        end
        return new(dag, name, inputs, outputs, wrapped, twosided)
    end
end

"""
    diff(f::DifferentiableCombinedExtendedFunction, shock_dict; h=nothing, outputs=nothing, hide_zeros=false, twosided=nothing)

Propagate the directional shock `shock_dict` through the blocks of `f` by finite
differences. Uses `diff2` (central) when `twosided` is true and `diff1` (forward)
otherwise.

When `twosided` is `nothing` (the default), `f.default_twosided` is used. The old default
of `false` made that fallback unreachable, unlike the single-function `diff`. Explicitly
passing `true`/`false` behaves as before.
"""
function diff(f::DifferentiableCombinedExtendedFunction, shock_dict; h=nothing, outputs=nothing, hide_zeros=false, twosided=nothing)
    use_central = isnothing(twosided) ? f.default_twosided : twosided
    if use_central
        return diff2(f, shock_dict; h=h, outputs=outputs, hide_zeros=hide_zeros)
    else
        return diff1(f, shock_dict; h=h, outputs=outputs, hide_zeros=hide_zeros)
    end
end

"""
    diff1(f::DifferentiableCombinedExtendedFunction, shock_dict; h=nothing, outputs=nothing, hide_zeros=false)

Forward-difference propagation of `shock_dict` through the blocks of `f` reachable from
the shocked inputs (and leading to `outputs`, if given). Each block's derivatives are
added to the shocks seen by later blocks.

Returns all computed derivatives, restricted to `outputs` when given.
"""
function diff1(f::DifferentiableCombinedExtendedFunction, shock_dict; h=nothing, outputs=nothing, hide_zeros=false)
    blocks = filter(f, collect(values(f.diff_functions)), shock_dict; outputs=outputs)

    collected = Dict()
    # Non-mutating `merge` below always builds a new Dict, so the caller's `shock_dict`
    # is never modified.
    shocks = shock_dict

    for block ∈ blocks
        derivs = diff1(block, shocks; h=h, hide_zeros=hide_zeros)
        merge!(collected, derivs)
        shocks = merge(shocks, derivs)
    end

    isnothing(outputs) && return collected
    return Dict(k => v for (k, v) ∈ collected if k ∈ outputs)
end

"""
    diff2(f::DifferentiableCombinedExtendedFunction, shock_dict; h=nothing, outputs=nothing, hide_zeros=false)

Central-difference propagation of `shock_dict` through the blocks of `f` reachable from
the shocked inputs (and leading to `outputs`, if given). Each block's derivatives are
added to the shocks seen by later blocks.

Returns all computed derivatives, restricted to `outputs` when given.
"""
function diff2(f::DifferentiableCombinedExtendedFunction, shock_dict; h=nothing, outputs=nothing, hide_zeros=false)
    functions_to_visit = filter(f, collect(values(f.diff_functions)), shock_dict; outputs=outputs)

    results = Dict()

    for g ∈ functions_to_visit
        out = diff2(g, shock_dict; h=h, hide_zeros=hide_zeros)
        merge!(results, out)
        # Use non-mutating `merge` (as `diff1` does). `merge!` into a copy of a typed dict
        # such as `Dict{String,Float64}` fails when a derivative has a different value
        # type (e.g. a vector); `merge` widens the value type and never touches the
        # caller's dict.
        shock_dict = merge(shock_dict, out)
    end

    if !isnothing(outputs)
        return Dict(k => v for (k, v) ∈ results if k ∈ outputs)
    else
        return results
    end
end
